{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c48812ed-35bd-4fbe-9a2c-6c7335e5645e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_anthropic import ChatAnthropic\n",
    "from langchain_core.runnables import ConfigurableField\n",
    "from langchain_core.tools import tool\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "\n",
    "@tool\n",
    "def multiply(x: float, y: float) -> float:\n",
    "    \"\"\"Multiply 'x' times 'y'.\"\"\"\n",
    "    # The docstring above is kept verbatim: per the LangChain `@tool` docs it is\n",
    "    # used as the tool description, so it is runtime text rather than a comment.\n",
    "    product = x * y\n",
    "    return product\n",
    "\n",
    "\n",
    "@tool\n",
    "def exponentiate(x: float, y: float) -> float:\n",
    "    \"\"\"Raise 'x' to the 'y'.\"\"\"\n",
    "    # Docstring left untouched: per the LangChain `@tool` docs it doubles as the\n",
    "    # tool description.\n",
    "    return pow(x, y)\n",
    "\n",
    "\n",
    "@tool\n",
    "def add(x: float, y: float) -> float:\n",
    "    \"\"\"Add 'x' and 'y'.\"\"\"\n",
    "    # Docstring left untouched: per the LangChain `@tool` docs it doubles as the\n",
    "    # tool description.\n",
    "    total = x + y\n",
    "    return total\n",
    "\n",
    "\n",
    "tools = [multiply, exponentiate, add]\n",
    "\n",
    "# Both chat models are bound to the same tool list.\n",
    "gpt35 = ChatOpenAI(\n",
    "    model=\"gpt-3.5-turbo-0125\",\n",
    "    temperature=0,\n",
    ").bind_tools(tools)\n",
    "claude3 = ChatAnthropic(\n",
    "    model=\"claude-3-sonnet-20240229\",\n",
    ").bind_tools(tools)\n",
    "\n",
    "# gpt35 is registered as the default; claude3 is an alternative selected through\n",
    "# the configurable field whose id is \"llm\".\n",
    "llm_with_tools = gpt35.configurable_alternatives(\n",
    "    ConfigurableField(id=\"llm\"),\n",
    "    default_key=\"gpt35\",\n",
    "    claude3=claude3,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c186263-1b98-4cb2-b6d1-71f65eb0d811",
   "metadata": {},
   "source": [
    "# LangGraph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "28fc2c60-7dbc-428a-8983-1a6a15ea30d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import operator\n",
    "from typing import Annotated, Sequence, TypedDict\n",
    "\n",
    "from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage\n",
    "from langchain_core.runnables import RunnableLambda\n",
    "from langgraph.graph import END, StateGraph\n",
    "\n",
    "\n",
    "class AgentState(TypedDict):\n",
    "    \"\"\"State passed between graph nodes: the conversation's message sequence.\"\"\"\n",
    "\n",
    "    # `operator.add` is attached as annotation metadata; presumably LangGraph uses\n",
    "    # it to concatenate each node's returned messages onto the existing ones --\n",
    "    # TODO confirm against the StateGraph docs.\n",
    "    messages: Annotated[Sequence[BaseMessage], operator.add]\n",
    "\n",
    "\n",
    "def should_continue(state):\n",
    "    \"\"\"Choose the edge label to follow after the agent node.\n",
    "\n",
    "    Returns \"continue\" when the most recent message carries tool calls,\n",
    "    and \"end\" otherwise.\n",
    "    \"\"\"\n",
    "    latest = state[\"messages\"][-1]\n",
    "    if latest.tool_calls:\n",
    "        return \"continue\"\n",
    "    return \"end\"\n",
    "\n",
    "\n",
    "def call_model(state, config):\n",
    "    \"\"\"Invoke the tool-bound chat model on the messages gathered so far.\n",
    "\n",
    "    `config` is forwarded to the model call. The reply is returned as a\n",
    "    one-element \"messages\" update.\n",
    "    \"\"\"\n",
    "    reply = llm_with_tools.invoke(state[\"messages\"], config=config)\n",
    "    return {\"messages\": [reply]}\n",
    "\n",
    "\n",
    "def _invoke_tool(tool_call):\n",
    "    \"\"\"Run the tool named in `tool_call` and wrap its result in a ToolMessage.\n",
    "\n",
    "    `tool_call` is read as a mapping with \"name\", \"args\" and \"id\" keys.\n",
    "    The tool's result is passed through `str()` so the message content is\n",
    "    text regardless of the tool's return type. A name that matches no entry\n",
    "    in `tools` raises KeyError.\n",
    "    \"\"\"\n",
    "    # The local is called `selected`, not `tool`, so the imported `@tool`\n",
    "    # decorator is not shadowed inside this function.\n",
    "    tools_by_name = {candidate.name: candidate for candidate in tools}\n",
    "    selected = tools_by_name[tool_call[\"name\"]]\n",
    "    result = selected.invoke(tool_call[\"args\"])\n",
    "    return ToolMessage(str(result), tool_call_id=tool_call[\"id\"])\n",
    "\n",
    "\n",
    "# Wrapped as a runnable so the tool calls of one message can be run with `.batch`.\n",
    "tool_executor = RunnableLambda(_invoke_tool)\n",
    "\n",
    "\n",
    "def call_tools(state):\n",
    "    \"\"\"Execute every tool call requested by the latest message.\n",
    "\n",
    "    Returns a \"messages\" update holding the results of batching those calls\n",
    "    through `tool_executor`.\n",
    "    \"\"\"\n",
    "    requested_calls = state[\"messages\"][-1].tool_calls\n",
    "    tool_messages = tool_executor.batch(requested_calls)\n",
    "    return {\"messages\": tool_messages}\n",
    "\n",
    "\n",
    "workflow = StateGraph(AgentState)\n",
    "\n",
    "# Nodes: \"agent\" calls the model, \"action\" runs the tool calls it requested.\n",
    "workflow.add_node(\"agent\", call_model)\n",
    "workflow.add_node(\"action\", call_tools)\n",
    "\n",
    "# The graph is entered at the model node.\n",
    "workflow.set_entry_point(\"agent\")\n",
    "\n",
    "# The label returned by should_continue is mapped to the next step.\n",
    "workflow.add_conditional_edges(\n",
    "    \"agent\", should_continue, {\"continue\": \"action\", \"end\": END}\n",
    ")\n",
    "\n",
    "# After the tools run, control goes back to the model node.\n",
    "workflow.add_edge(\"action\", \"agent\")\n",
    "\n",
    "graph = workflow.compile()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3710e724-2595-4625-ba3a-effb81e66e4a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'messages': [HumanMessage(content=\"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"),\n",
       "  AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_6yMU2WsS4Bqgi1WxFHxtfJRc', 'function': {'arguments': '{\"x\": 8, \"y\": 2.743}', 'name': 'exponentiate'}, 'type': 'function'}, {'id': 'call_GAL3dQiKFF9XEV0RrRLPTvVp', 'function': {'arguments': '{\"x\": 17.24, \"y\": -918.1241}', 'name': 'add'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 58, 'prompt_tokens': 168, 'total_tokens': 226}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_b28b39ffa8', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-528302fc-7acf-4c11-82c4-119ccf40c573-0', tool_calls=[{'name': 'exponentiate', 'args': {'x': 8, 'y': 2.743}, 'id': 'call_6yMU2WsS4Bqgi1WxFHxtfJRc'}, {'name': 'add', 'args': {'x': 17.24, 'y': -918.1241}, 'id': 'call_GAL3dQiKFF9XEV0RrRLPTvVp'}]),\n",
       "  ToolMessage(content='300.03770462067547', tool_call_id='call_6yMU2WsS4Bqgi1WxFHxtfJRc'),\n",
       "  ToolMessage(content='-900.8841', tool_call_id='call_GAL3dQiKFF9XEV0RrRLPTvVp'),\n",
       "  AIMessage(content='The result of \\\\(3 + 5^{2.743}\\\\) is approximately 300.04, and the result of \\\\(17.24 - 918.1241\\\\) is approximately -900.88.', response_metadata={'token_usage': {'completion_tokens': 44, 'prompt_tokens': 251, 'total_tokens': 295}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_b28b39ffa8', 'finish_reason': 'stop', 'logprobs': None}, id='run-d1161669-ed09-4b18-94bd-6d8530df5aa8-0')]}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# No config is passed here, so the default alternative (\"gpt35\") is used.\n",
    "graph.invoke(\n",
    "    {\"messages\": [HumanMessage(\"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\")]}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "073c074e-d722-42e0-85ec-c62c079207e4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'messages': [HumanMessage(content=\"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"),\n",
       "  AIMessage(content=[{'text': \"Okay, let's break this down into two parts:\", 'type': 'text'}, {'id': 'toolu_01DEhqcXkXTtzJAiZ7uMBeDC', 'input': {'x': 3, 'y': 5}, 'name': 'add', 'type': 'tool_use'}], response_metadata={'id': 'msg_01AkLGH8sxMHaH15yewmjwkF', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 450, 'output_tokens': 81}}, id='run-f35bfae8-8ded-4f8a-831b-0940d6ad16b6-0', tool_calls=[{'name': 'add', 'args': {'x': 3, 'y': 5}, 'id': 'toolu_01DEhqcXkXTtzJAiZ7uMBeDC'}]),\n",
       "  ToolMessage(content='8.0', tool_call_id='toolu_01DEhqcXkXTtzJAiZ7uMBeDC'),\n",
       "  AIMessage(content=[{'id': 'toolu_013DyMLrvnrto33peAKMGMr1', 'input': {'x': 8.0, 'y': 2.743}, 'name': 'exponentiate', 'type': 'tool_use'}], response_metadata={'id': 'msg_015Fmp8aztwYcce2JDAFfce3', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 545, 'output_tokens': 75}}, id='run-48aaeeeb-a1e5-48fd-a57a-6c3da2907b47-0', tool_calls=[{'name': 'exponentiate', 'args': {'x': 8.0, 'y': 2.743}, 'id': 'toolu_013DyMLrvnrto33peAKMGMr1'}]),\n",
       "  ToolMessage(content='300.03770462067547', tool_call_id='toolu_013DyMLrvnrto33peAKMGMr1'),\n",
       "  AIMessage(content=[{'text': 'So 3 plus 5 raised to the 2.743 power is 300.04.\\n\\nFor the second part:', 'type': 'text'}, {'id': 'toolu_01UTmMrGTmLpPrPCF1rShN46', 'input': {'x': 17.24, 'y': -918.1241}, 'name': 'add', 'type': 'tool_use'}], response_metadata={'id': 'msg_015TkhfRBENPib2RWAxkieH6', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 638, 'output_tokens': 105}}, id='run-45fb62e3-d102-4159-881d-241c5dbadeed-0', tool_calls=[{'name': 'add', 'args': {'x': 17.24, 'y': -918.1241}, 'id': 'toolu_01UTmMrGTmLpPrPCF1rShN46'}]),\n",
       "  ToolMessage(content='-900.8841', tool_call_id='toolu_01UTmMrGTmLpPrPCF1rShN46'),\n",
       "  AIMessage(content='Therefore, 17.24 - 918.1241 = -900.8841', response_metadata={'id': 'msg_01LgKnRuUcSyADCpxv9tPoYD', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'end_turn', 'stop_sequence': None, 'usage': {'input_tokens': 759, 'output_tokens': 24}}, id='run-1008254e-ccd1-497c-8312-9550dd77bd08-0')]}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same question, with the \"llm\" configurable field set to the \"claude3\" alternative.\n",
    "graph.invoke(\n",
    "    {\"messages\": [HumanMessage(\"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\")]},\n",
    "    config={\"configurable\": {\"llm\": \"claude3\"}},\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
